from build import mnistModel
from train import getData


def main(*, checkpoint_path='checkout/', log_path='log', model_path='model/model.h5'):
    """Load the data, build, compile, train and evaluate the MNIST model.

    The datasets and image/label dimensions come from ``getData()``. They are
    used to construct an ``mnistModel`` wrapper, which is compiled and trained
    on the training dataset, with the validation dataset passed alongside.
    The model is then evaluated and the resulting loss and accuracy are
    printed.

    All arguments are keyword-only. Their defaults are the paths this script
    has always used, so calling ``main()`` with no arguments behaves as before.

    Args:
        checkpoint_path: Checkpoint path forwarded to ``mnistModel``.
            NOTE(review): ``'checkout/'`` may be a typo for ``'checkpoint/'``;
            it is kept unchanged so any existing checkpoints are still found.
        log_path: Log path forwarded to ``mnistModel``.
        model_path: Model file path forwarded to ``mnistModel``.

    Returns:
        The two values produced by ``model.evaluate`` on the validation
        dataset, as ``(loss, accuracy)``.
    """
    # Load the datasets plus the class count and image size the model needs.
    train_ds, val_ds, num_classes, img_width, img_height = getData()

    # Build the model wrapper.
    model = mnistModel(checkpoint_path, log_path, model_path,
                       num_classes, img_width, img_height)
    # Compile, then train; val_ds is handed to train() alongside train_ds.
    model.compile()
    model.train(train_ds, val_ds)

    # getData() yields only a training and a validation dataset, so these are
    # validation metrics rather than held-out test metrics.
    val_loss, val_acc = model.evaluate(val_ds)
    print(val_loss, val_acc)
    return val_loss, val_acc


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
